import numpy as np

class RevealingGoal:
    """Cooperative multi-agent grid world on a torus with hidden goals.

    Every agent has one goal cell.  An agent's observation contains the
    goals of all *other* agents but never its own; instead it sees a "hint"
    plane that teammates fill in by using reveal actions.  All agents share
    one team reward, incremented each time any agent stands on its own goal.

    Actions (per agent):
        0-3  move up / down / right / left (coordinates wrap around)
        4-7  reveal the adjacent cell up / down / right / left: for every
             other agent, copy that agent's goal value at the cell into that
             agent's hint plane.

    Attributes:
        num_agents: number of agents.
        grid_size: side length of the square grid.
        action_space: number of discrete actions per agent (8).
        observation_space: per-agent observation shape,
            ``[3 * num_agents, grid_size, grid_size]``.
        map: dict with ``'agents'`` (``[num_agents, 2]`` positions),
            ``'goals'`` (one-hot plane per agent) and ``'hint'`` (plane per
            agent; -1 means the cell has not been revealed yet).
        done: episode-finished flag; nothing in this class sets it to True.
    """

    # (dx, dy) offsets shared by move actions 0-3 and reveal actions 4-7,
    # in the order up, down, right, left.
    _DIRECTIONS = ((0, 1), (0, -1), (1, 0), (-1, 0))

    def __init__(self, num_agents=4, grid_size=5):
        """Create the environment and sample an initial map.

        Args:
            num_agents: number of agents.
            grid_size: side length of the grid; must be at least 2 so that a
                goal more than one step away from its agent exists.

        Raises:
            ValueError: if ``grid_size`` is smaller than 2.
        """
        if grid_size < 2:
            # With fewer than 2 cells per side no goal can be more than one
            # step away, so goal sampling would never terminate.
            raise ValueError('grid_size must be at least 2', grid_size)

        self.num_agents = num_agents
        self.grid_size = grid_size

        self.action_space = 4 + 4
        self.observation_space = [3 * self.num_agents, self.grid_size, self.grid_size]

        self.guess_space_noop = 4 + 1
        self.reveal_space_noop = 4 + 1

        self.map = self._new_map()
        self.done = False

    def _torus_distance(self, loc_a, loc_b):
        """Return the wrap-around Manhattan distance between two cells."""
        delta = np.abs(np.asarray(loc_a) - np.asarray(loc_b))
        return int(np.minimum(delta, self.grid_size - delta).sum())

    def _sample_goal(self, agent_loc):
        """Draw a uniform random cell more than one step away from agent_loc."""
        while True:
            g_loc = np.random.randint(0, self.grid_size, 2)
            if self._torus_distance(agent_loc, g_loc) > 1:
                return g_loc

    def _new_map(self):
        """Build a fresh map with random agent positions and goals."""
        plane_shape = [self.num_agents, self.grid_size, self.grid_size]
        new_map = {
            'agents': np.zeros([self.num_agents, 2], dtype=np.int64),
            'goals': np.zeros(plane_shape, dtype=np.int64),
            'hint': np.full(plane_shape, -1, dtype=np.int64),
        }
        for i in range(self.num_agents):
            a_loc = np.random.randint(0, self.grid_size, 2)
            g_loc = self._sample_goal(a_loc)
            new_map['agents'][i] = a_loc
            new_map['goals'][i, g_loc[0], g_loc[1]] = 1
        return new_map

    def _decode_action(self, action):
        """Map an action to ``(is_reveal, (dx, dy))``.

        Matching is done by equality so any value that compares equal to an
        integer in ``range(action_space)`` is accepted.

        Raises:
            ValueError: if the action matches no known action code.
        """
        for code in range(self.action_space):
            if action == code:
                is_reveal, direction = divmod(code, len(self._DIRECTIONS))
                return bool(is_reveal), self._DIRECTIONS[direction]
        raise ValueError('Invalid action', action)

    def reset(self):
        """Resample the map and start a new episode.

        Returns:
            ``(obs, available_actions)`` as produced by
            :meth:`get_observations` and :meth:`available_actions`.
        """
        self.map = self._new_map()
        self.done = False

        obs = self.get_observations()
        available_actions = self.available_actions()

        return obs, available_actions

    def get_observations(self):
        """Return one observation array per agent.

        Each observation has shape ``[3 * num_agents, grid_size, grid_size]``
        and is laid out relative to the observing agent ``i`` (slot ``j``
        refers to agent ``(i + j) % num_agents``, so slot 0 is the observer):

            channels ``[0, N)``    one-hot position of each agent
            channels ``[N, 2N)``   hint plane of each agent
            channels ``[2N, 3N)``  goal plane of each agent; slot 0 (the
                                   observer's own goal) is left all zeros
        """
        n = self.num_agents
        obs = []
        for i in range(n):
            o = np.zeros([3 * n, self.grid_size, self.grid_size])
            for j in range(n):
                k = (i + j) % n
                x, y = self.map['agents'][k]
                o[j, x, y] = 1
                o[n + j] = self.map['hint'][k]
                if j != 0:
                    # The observer must not see its own goal.
                    o[2 * n + j] = self.map['goals'][k]
            obs.append(o)
        return obs

    def available_actions(self):
        """Return an all-ones action mask of length ``action_space`` per agent."""
        return [np.ones(self.action_space) for _ in range(self.num_agents)]

    def step(self, actions):
        """Apply one action per agent, in agent order, and advance the game.

        Args:
            actions: indexable of length ``num_agents`` holding action codes
                in ``range(action_space)``.

        Returns:
            ``(obs, rewards, done, available_actions)`` where ``rewards`` is
            the shared team reward repeated once per agent.

        Raises:
            AssertionError: if the episode is already done.
            ValueError: if an action is not a valid action code.  Agents
                processed before the offending one have already acted.
        """
        assert not self.done, "Game Over"

        reward = 0
        agents = self.map['agents']
        goals = self.map['goals']
        hints = self.map['hint']

        for i in range(self.num_agents):
            is_reveal, (dx, dy) = self._decode_action(actions[i])
            tx = (agents[i, 0] + dx) % self.grid_size
            ty = (agents[i, 1] + dy) % self.grid_size

            if is_reveal:
                # Tell every teammate whether its goal lies on the target cell.
                for j in range(self.num_agents):
                    if j != i:
                        hints[j, tx, ty] = goals[j, tx, ty]
            else:
                agents[i, 0] = tx
                agents[i, 1] = ty

            if goals[i, agents[i, 0], agents[i, 1]] == 1:
                reward += 1

                # Goal reached: draw a new one and forget the stale hints.
                g_loc = self._sample_goal(agents[i])
                goals[i] = 0
                goals[i, g_loc[0], g_loc[1]] = 1
                hints[i] = -1

        obs = self.get_observations()
        available_actions = self.available_actions()

        return obs, [reward] * self.num_agents, self.done, available_actions
